# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

#     http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class SuperresNetwork(nn.Module):
    """Convolutional 2D super-resolution network with pixel-shuffle upscaling.

    Architecture reference: https://arxiv.org/abs/1802.09987

    The network consists of:

    * a 3x3 convolution + batch norm stem mapping 6 input channels to 128;
    * 16 residual branches (conv, batch norm, ReLU, conv, batch norm) whose
      outputs are accumulated onto a running skip tensor;
    * a 3x3 convolution + batch norm whose output is added to the stem output;
    * ``k`` upscaling stages, each a 3x3 convolution followed by
      ``nn.PixelShuffle(2)`` (which doubles height and width and divides the
      channel count by 4, i.e. 128 -> 32);
    * a 1x1 convolution down to 6 channels followed by a sigmoid.

    Here ``k = floor(log2(high // low))``, so the spatial scale factor is
    ``2 ** k``. When ``high // low`` is a power of two this equals
    ``high // low``; otherwise it is rounded down to the nearest power of two.

    Input shape: B x 6 x H x W
    Output shape: B x 6 x (2**k * H) x (2**k * W), values in [0, 1]

    Args:
        high: Numerator of the upscaling ratio (presumably the target
            resolution -- only ``high // low`` is used by this class).
        low: Denominator of the upscaling ratio (presumably the input
            resolution -- only ``high // low`` is used by this class).

    Raises:
        ValueError: If ``high // low`` is smaller than 2. At least one
            upscaling stage is required because the output convolution
            expects the 32 channels produced by a pixel-shuffle stage.

    .. note::

        If you use this code, please cite the original paper in addition to Kaolin.

        .. code-block::

            @incollection{ODM,
                title = {Multi-View Silhouette and Depth Decomposition for High Resolution 3D Object Representation},
                author = {Smith, Edward and Fujimoto, Scott and Meger, David},
                booktitle = {Advances in Neural Information Processing Systems 31},
                editor = {S. Bengio and H. Wallach and H. Larochelle and K. Grauman and N. Cesa-Bianchi and R. Garnett},
                pages = {6479--6489},
                year = {2018},
                publisher = {Curran Associates, Inc.},
                url = {http://papers.nips.cc/paper/7883-multi-view-silhouette-and-depth-decomposition-for-high-resolution-3d-object-representation.pdf}
            }
    """

    # Number of residual branches between the stem and the upscaling stages.
    _NUM_RES_BLOCKS = 16

    def __init__(self, high, low):
        super(SuperresNetwork, self).__init__()
        self.ratio = high // low
        if self.ratio < 2:
            # Fail at construction time with a clear message instead of a
            # channel-mismatch error (or a math domain error) later on.
            raise ValueError(
                'high // low must be at least 2, got {}'.format(self.ratio))
        # math.log2 is the accurate way to take a base-2 logarithm;
        # math.log(x, 2) divides two floats and may be off by one ulp, which
        # int() could then truncate to the wrong stage count.
        num_stages = int(math.log2(self.ratio))

        # NOTE: attribute names and construction order are kept stable so
        # that state_dict keys and seeded initialisation do not change.
        self.layer1 = nn.Sequential(
            nn.Conv2d(6, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128))
        self.inner_convs_1 = nn.ModuleList([
            nn.Conv2d(128, 128, kernel_size=3, padding=1)
            for _ in range(self._NUM_RES_BLOCKS)])
        self.inner_bns_1 = nn.ModuleList([
            nn.BatchNorm2d(128) for _ in range(self._NUM_RES_BLOCKS)])
        self.inner_convs_2 = nn.ModuleList([
            nn.Conv2d(128, 128, kernel_size=3, padding=1)
            for _ in range(self._NUM_RES_BLOCKS)])
        self.inner_bns_2 = nn.ModuleList([
            nn.BatchNorm2d(128) for _ in range(self._NUM_RES_BLOCKS)])
        self.layer2 = nn.Sequential(
            nn.Conv2d(128, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
        )

        # Each stage is conv -> PixelShuffle(2). The first stage consumes the
        # 128-channel trunk output; later stages consume the 32 channels left
        # after the previous pixel shuffle.
        upscale_layers = []
        in_channels = 128
        for _ in range(num_stages):
            upscale_layers.append(
                nn.Conv2d(in_channels, 128, kernel_size=3, padding=1))
            upscale_layers.append(nn.PixelShuffle(2))
            in_channels = 32
        self.sub_list = nn.ModuleList(upscale_layers)

        self.layer3 = nn.Sequential(
            nn.Conv2d(32, 6, kernel_size=1, padding=0),
        )

    def forward(self, x):
        """Upscale a batch of 6-channel images.

        Args:
            x (torch.Tensor): Input of shape B x 6 x H x W.

        Returns:
            torch.Tensor: Tensor of shape B x 6 x (s * H) x (s * W), where
            ``s`` is 2 raised to the number of upscaling stages. Values lie
            in [0, 1] because of the final sigmoid.
        """
        x = self.layer1(x)

        # NOTE(review): every residual branch below is fed the stem output
        # `x`, not the output of the previous branch; only the skip
        # accumulator is chained. This differs from a sequential stack of
        # residual blocks and may be unintentional, but it is kept as-is so
        # that weights trained with this forward pass keep producing the
        # same outputs.
        skip = x
        branches = zip(self.inner_convs_1, self.inner_bns_1,
                       self.inner_convs_2, self.inner_bns_2)
        for conv_a, bn_a, conv_b, bn_b in branches:
            # No .clone() is needed: none of these ops modify their input
            # in place.
            branch = F.relu(bn_a(conv_a(x)))
            branch = bn_b(conv_b(branch))
            skip = branch + skip

        # Long skip connection around the whole residual trunk.
        x = x + self.layer2(skip)

        # sub_list holds the (conv, PixelShuffle) pairs in application order.
        for layer in self.sub_list:
            x = layer(x)

        return torch.sigmoid(self.layer3(x))
